{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a CNN on MNIST using evotorch\n",
    "\n",
    "This example demonstrates the application of the `evotorch.neuroevolution.SupervisedNE` `Problem` class to training a CNN on MNIST. This example follows set-up described in the recent DeepMind paper [1].\n",
    "\n",
    "Note that to use this example, please ensure that [torchvision](https://pytorch.org/vision/stable/index.html) is installed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up the `Problem` class\n",
    "First we will define the model. For this example, we will use the 'MNIST-30k' model from the paper, which is defined below. Note that Table 1 has a typo; the second convolution should have a 5x5 kernel, rather than a 2x2 kernel. This gives the number of parameters the authors listed. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from evotorch.neuroevolution.net import count_parameters\n",
    "\n",
    "class MNIST30K(nn.Module):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        # The first convolution uses a 5x5 kernel and has 16 filters\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size = 5, stride = 1, padding=2)\n",
    "        # Then max pooling is applied with a kernel size of 2\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size = 2)\n",
    "        # The second convolution uses a 5x5 kernel and has 32 filters\n",
    "        self.conv2 = nn.Conv2d(16, 32, kernel_size = 5, stride = 1, padding = 2)\n",
    "        # Another max pooling is applied with a kernel size of 2\n",
    "        self.pool2 = nn.MaxPool2d(kernel_size = 2)\n",
    "        \n",
    "        # Apply layer normalization after the second pool\n",
    "        self.norm = nn.LayerNorm(1568, elementwise_affine=False)\n",
    "        \n",
    "        # A final linear layer maps outputs to the 10 target classes\n",
    "        self.out = nn.Linear(1568, 10)\n",
    "        \n",
    "        # All activations are ReLU\n",
    "        self.act = nn.ReLU()\n",
    "        \n",
    "    def forward(self, data: torch.Tensor) -> torch.Tensor:\n",
    "        # Apply the first conv + pool\n",
    "        data = self.pool1(self.act(self.conv1(data)))\n",
    "        # Apply the second conv + pool\n",
    "        data = self.pool2(self.act(self.conv2(data)))\n",
    "        \n",
    "        # Apply layer norm\n",
    "        data = self.norm(data.flatten(start_dim = 1))\n",
    "        \n",
    "        # Flatten and apply the output linear layer\n",
    "        data = self.out(data)\n",
    "        \n",
    "        return data\n",
    "        \n",
    "network = MNIST30K()\n",
    "print(f'Network has {count_parameters(network)} parameters')"
   ]
  },
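  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the kernel-size correction, the parameter count can be worked out by hand. This is a sketch that assumes each convolution and the linear layer carry a bias term (the PyTorch default) and that the layer normalization contributes no parameters because `elementwise_affine=False`. The flattened size is $32 \\times 7 \\times 7 = 1568$, since the padded convolutions preserve the spatial size and each pooling step halves the $28 \\times 28$ input.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{conv1} &: 1 \\cdot 16 \\cdot 5 \\cdot 5 + 16 = 416 \\\\\\\\\n",
    "\\text{conv2} &: 16 \\cdot 32 \\cdot 5 \\cdot 5 + 32 = 12832 \\\\\\\\\n",
    "\\text{out} &: 1568 \\cdot 10 + 10 = 15690 \\\\\\\\\n",
    "\\text{total} &: 416 + 12832 + 15690 = 28938\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The total is roughly 30k, in line with the model's name, and should agree with the value printed above. With a 2x2 kernel, the second convolution would contribute only $16 \\cdot 32 \\cdot 2 \\cdot 2 + 32 = 2080$ parameters."
   ]
  },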
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now lets pull the dataset (to use with standard transforms). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "train_dataset = datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transform)\n",
    "test_dataset = datasets.MNIST('../data', train=False,\n",
    "                   transform=transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to create a custom problem class. The below is configured to use 4 actors, and divide the available GPUs between them. You can scale this up to dozens or even hundreds of CPUs and GPUs on a `ray` cluster simply by modifying the `num_actors` parameter. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from evotorch.neuroevolution import SupervisedNE\n",
    "\n",
    "mnist_problem = SupervisedNE(\n",
    "    train_dataset,  # Using the dataset specified earlier\n",
    "    MNIST30K,  # Training the MNIST30K module designed earlier\n",
    "    nn.CrossEntropyLoss(),  # Minimizing CrossEntropyLoss\n",
    "    minibatch_size = 1024,  # With a minibatch size of 1024\n",
    "    common_minibatch = True,  # Always using the same minibatch across all solutions on an actor\n",
    "    num_actors = 4,  # The total number of CPUs used\n",
    "    num_gpus_per_actor = 'max',  # Dividing all available GPUs between the 4 actors\n",
    "    subbatch_size = 50,  # Evaluating solutions in sub-batches of size 50 ensures we won't run out of GPU memory for individual workers\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "Now we can set up the searcher.\n",
    "\n",
    "In the paper, they used SNES with, effectively, default parameters, and standard deviation 1. The authors achieved 98%+ with only a population size of 1k, but this value can be pushed further. Note that by using the `distributed = True` keyword argument, we obtain semi-updates from the individual actors which are averaged.\n",
    "\n",
    "In our example, we use PGPE with a population size of 3200. Hyperparameter configuration can be seen below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from evotorch.algorithms import PGPE\n",
    "\n",
    "searcher = PGPE(\n",
    "    mnist_problem,\n",
    "    radius_init=2.25, # Initial radius of the search distribution\n",
    "    center_learning_rate=1e-2, # Learning rate used by adam optimizer\n",
    "    stdev_learning_rate=0.1, # Learning rate for the standard deviation\n",
    "    popsize=3200,\n",
    "    distributed=True, # Gradients are computed locally at actors and averaged\n",
    "    optimizer=\"adam\", # Using the adam optimizer\n",
    "    ranking_method=None, # No rank-based fitness shaping is used\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create some loggers. We'll run evolution for quite a long time, so it's worth reducing the log frequency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from evotorch.logging import StdOutLogger, PandasLogger\n",
    "stdout_logger = StdOutLogger(searcher, interval = 1)\n",
    "pandas_logger = PandasLogger(searcher, interval = 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running evolution for 400 generations (note that in the paper, it was 10k generations)..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "searcher.run(400)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can visualize the progress:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pandas_logger.to_dataframe().mean_eval.plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And of course, it is worth while to measure the test performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#net = mnist_problem.parameterize_net(searcher.status['center']).cpu()\n",
    "net = mnist_problem.make_net(searcher.status[\"center\"]).cpu()\n",
    "\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "net.eval()\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 256, shuffle = False)\n",
    "test_loss = 0\n",
    "correct = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for data, target in test_loader:\n",
    "        output = net(data)\n",
    "        test_loss += loss(output, target).item() * data.shape[0]\n",
    "        pred = output.data.max(1, keepdim=True)[1]\n",
    "        correct += pred.eq(target.data.view_as(pred)).sum()\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "print('Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "test_loss, correct, len(test_loader.dataset),\n",
    "100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mnist_problem.kill_actors()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### References\n",
    "\n",
    "[1] Lenc, Karel, et al. [\"Non-differentiable supervised learning with evolution strategies and hybrid methods.\"](https://www.deepmind.com/publications/non-differentiable-supervised-learning-with-evolution-strategies-and-hybrid-methods) arXiv preprint arXiv:1906.03139 (2019).\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
